import json
from collections.abc import Generator
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
from typing import cast

from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import LlmStepResult
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import ChatFileType
from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.message_types import AssistantMessage
from onyx.llm.message_types import ChatCompletionMessage
from onyx.llm.message_types import ImageContentPart
from onyx.llm.message_types import SystemMessage
from onyx.llm.message_types import TextContentPart
from onyx.llm.message_types import ToolCall
from onyx.llm.message_types import ToolMessage
from onyx.llm.message_types import UserMessageWithParts
from onyx.llm.message_types import UserMessageWithText
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningDone
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.tools.models import ToolCallKickoff
from onyx.tracing.framework.create import generation_span
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger


logger = setup_logger()


TOOL_CALL_MSG_FUNC_NAME = "function_name"
TOOL_CALL_MSG_ARGUMENTS = "arguments"


def _format_message_history_for_logging(
    message_history: LanguageModelInput,
) -> str:
    """Format message history for logging, with special handling for tool calls.

    Tool calls are formatted as JSON with 4-space indentation for readability.
    """
    formatted_lines = []

    separator = "================================================"

    # Handle string input
    if isinstance(message_history, str):
        formatted_lines.append("Message [string]:")
        formatted_lines.append(separator)
        formatted_lines.append(f"{message_history}")
        return "\n".join(formatted_lines)

    # Handle sequence of messages
    for i, msg in enumerate(message_history):
        # Type guard: ensure msg is a dict-like object (TypedDict)
        if not isinstance(msg, dict):
            formatted_lines.append(f"Message {i + 1} [unknown]:")
            formatted_lines.append(separator)
            formatted_lines.append(f"{msg}")
            if i < len(message_history) - 1:
                formatted_lines.append(separator)
            continue

        role = msg.get("role", "unknown")
        formatted_lines.append(f"Message {i + 1} [{role}]:")
        formatted_lines.append(separator)

        if role == "system":
            content = msg.get("content", "")
            if isinstance(content, str):
                formatted_lines.append(f"{content}")

        elif role == "user":
            content = msg.get("content", "")
            if isinstance(content, str):
                formatted_lines.append(f"{content}")
            elif isinstance(content, list):
                # Handle multimodal content (text + images)
                for part in content:
                    if isinstance(part, dict):
                        part_type = part.get("type")
                        if part_type == "text":
                            text = part.get("text", "")
                            if isinstance(text, str):
                                formatted_lines.append(f"{text}")
                        elif part_type == "image_url":
                            image_url_dict = part.get("image_url")
                            if isinstance(image_url_dict, dict):
                                url = image_url_dict.get("url", "")
                                if isinstance(url, str):
                                    formatted_lines.append(f"[Image: {url[:50]}...]")

        elif role == "assistant":
            content = msg.get("content")
            if content and isinstance(content, str):
                formatted_lines.append(f"{content}")

            tool_calls = msg.get("tool_calls")
            if tool_calls and isinstance(tool_calls, list):
                formatted_lines.append("Tool calls:")
                for tool_call in tool_calls:
                    if isinstance(tool_call, dict):
                        tool_call_dict: dict[str, Any] = {}
                        tool_call_id = tool_call.get("id")
                        tool_call_type = tool_call.get("type")
                        function_dict = tool_call.get("function")

                        if tool_call_id:
                            tool_call_dict["id"] = tool_call_id
                        if tool_call_type:
                            tool_call_dict["type"] = tool_call_type
                        if isinstance(function_dict, dict):
                            tool_call_dict["function"] = {
                                "name": function_dict.get("name", ""),
                                "arguments": function_dict.get("arguments", ""),
                            }

                        tool_call_json = json.dumps(tool_call_dict, indent=4)
                        formatted_lines.append(tool_call_json)

        elif role == "tool":
            content = msg.get("content", "")
            tool_call_id = msg.get("tool_call_id", "")
            if isinstance(content, str) and isinstance(tool_call_id, str):
                formatted_lines.append(f"Tool call ID: {tool_call_id}")
                formatted_lines.append(f"Response: {content}")

        # Add separator before next message (or at end)
        if i < len(message_history) - 1:
            formatted_lines.append(separator)

    return "\n".join(formatted_lines)


def _update_tool_call_with_delta(
    tool_calls_in_progress: dict[int, dict[str, Any]],
    tool_call_delta: Any,
) -> None:
    index = tool_call_delta.index

    if index not in tool_calls_in_progress:
        tool_calls_in_progress[index] = {
            "id": None,
            "name": None,
            "arguments": "",
        }

    if tool_call_delta.id:
        tool_calls_in_progress[index]["id"] = tool_call_delta.id

    if tool_call_delta.function:
        if tool_call_delta.function.name:
            tool_calls_in_progress[index]["name"] = tool_call_delta.function.name

        if tool_call_delta.function.arguments:
            tool_calls_in_progress[index][
                "arguments"
            ] += tool_call_delta.function.arguments


def _extract_tool_call_kickoffs(
    id_to_tool_call_map: dict[int, dict[str, Any]],
) -> list[ToolCallKickoff]:
    """Extract ToolCallKickoff objects from the tool call map.

    Returns a list of ToolCallKickoff objects for valid tool calls (those with both id and name).
    """
    tool_calls: list[ToolCallKickoff] = []
    for tool_call_data in id_to_tool_call_map.values():
        if tool_call_data.get("id") and tool_call_data.get("name"):
            try:
                # Parse arguments JSON string to dict
                tool_args = (
                    json.loads(tool_call_data["arguments"])
                    if tool_call_data["arguments"]
                    else {}
                )
            except json.JSONDecodeError:
                # If parsing fails, try empty dict, most tools would fail though
                logger.error(
                    f"Failed to parse tool call arguments: {tool_call_data['arguments']}"
                )
                tool_args = {}

            tool_calls.append(
                ToolCallKickoff(
                    tool_call_id=tool_call_data["id"],
                    tool_name=tool_call_data["name"],
                    tool_args=tool_args,
                )
            )
    return tool_calls


def translate_history_to_llm_format(
    history: list[ChatMessageSimple],
) -> LanguageModelInput:
    """Convert a list of ChatMessageSimple to LanguageModelInput format.

    Converts ChatMessageSimple messages to ChatCompletionMessage format,
    handling different message types and image files for multimodal support.
    """
    messages: list[ChatCompletionMessage] = []

    for msg in history:
        if msg.message_type == MessageType.SYSTEM:
            system_msg: SystemMessage = {
                "role": "system",
                "content": msg.message,
            }
            messages.append(system_msg)

        elif msg.message_type == MessageType.USER:
            # Handle user messages with potential images
            if msg.image_files:
                # Build content parts: text + images
                content_parts: list[TextContentPart | ImageContentPart] = [
                    {"type": "text", "text": msg.message}
                ]

                # Add image parts
                for img_file in msg.image_files:
                    if img_file.file_type == ChatFileType.IMAGE:
                        try:
                            image_type = get_image_type_from_bytes(img_file.content)
                            base64_data = img_file.to_base64()
                            image_url = f"data:{image_type};base64,{base64_data}"

                            image_part: ImageContentPart = {
                                "type": "image_url",
                                "image_url": {"url": image_url},
                            }
                            content_parts.append(image_part)
                        except Exception as e:
                            logger.warning(
                                f"Failed to process image file {img_file.file_id}: {e}. "
                                "Skipping image."
                            )

                user_msg_with_parts: UserMessageWithParts = {
                    "role": "user",
                    "content": content_parts,
                }
                messages.append(user_msg_with_parts)
            else:
                # Simple text-only user message
                user_msg_text: UserMessageWithText = {
                    "role": "user",
                    "content": msg.message,
                }
                messages.append(user_msg_text)

        elif msg.message_type == MessageType.ASSISTANT:
            assistant_msg: AssistantMessage = {
                "role": "assistant",
                "content": msg.message or None,
            }
            messages.append(assistant_msg)

        elif msg.message_type == MessageType.TOOL_CALL:
            # Tool calls are represented as Assistant Messages with tool_calls field
            # Try to reconstruct tool call structure if we have tool_call_id
            tool_calls: list[ToolCall] = []
            if msg.tool_call_id:
                try:
                    # Parse the message content (which should contain function_name and arguments)
                    tool_call_data = json.loads(msg.message) if msg.message else {}

                    if (
                        isinstance(tool_call_data, dict)
                        and TOOL_CALL_MSG_FUNC_NAME in tool_call_data
                    ):
                        function_name = tool_call_data.get(
                            TOOL_CALL_MSG_FUNC_NAME, "unknown"
                        )
                        tool_args = tool_call_data.get(TOOL_CALL_MSG_ARGUMENTS, {})
                    else:
                        function_name = "unknown"
                        tool_args = (
                            tool_call_data if isinstance(tool_call_data, dict) else {}
                        )

                    # NOTE: if the model is trained on a different tool call format, this may slightly interfere
                    # with the future tool calls, if it doesn't look like this. Almost certainly not a big deal.
                    tool_call: ToolCall = {
                        "id": msg.tool_call_id,
                        "type": "function",
                        "function": {
                            "name": function_name,
                            "arguments": json.dumps(tool_args) if tool_args else "{}",
                        },
                    }
                    tool_calls.append(tool_call)
                except (json.JSONDecodeError, ValueError) as e:
                    logger.warning(
                        f"Failed to parse tool call data for tool_call_id {msg.tool_call_id}: {e}. "
                        "Including as content-only message."
                    )

            assistant_msg_with_tool: AssistantMessage = {
                "role": "assistant",
                "content": None,  # The tool call is parsed, doesn't need to be duplicated in the content
            }
            if tool_calls:
                assistant_msg_with_tool["tool_calls"] = tool_calls
            messages.append(assistant_msg_with_tool)

        elif msg.message_type == MessageType.TOOL_CALL_RESPONSE:
            if not msg.tool_call_id:
                raise ValueError(
                    f"Tool call response message encountered but tool_call_id is not available. Message: {msg}"
                )

            tool_msg: ToolMessage = {
                "role": "tool",
                "content": msg.message,
                "tool_call_id": msg.tool_call_id,
            }
            messages.append(tool_msg)

        else:
            logger.warning(
                f"Unknown message type {msg.message_type} in history. Skipping message."
            )

    return messages


def run_llm_step(
    history: list[ChatMessageSimple],
    tool_definitions: list[dict],
    tool_choice: ToolChoiceOptions,
    llm: LLM,
    turn_index: int,
    citation_processor: DynamicCitationProcessor,
    state_container: ChatStateContainer,
    final_documents: list[SearchDoc] | None = None,
) -> Generator[Packet, None, tuple[LlmStepResult, int]]:
    # The second return value is for the turn index because reasoning counts on the frontend as a turn
    # TODO this is maybe ok but does not align well with the backend logic too well
    llm_msg_history = translate_history_to_llm_format(history)

    # Uncomment the line below to log the entire message history to the console
    if LOG_ONYX_MODEL_INTERACTIONS:
        logger.info(
            f"Message history:\n{_format_message_history_for_logging(llm_msg_history)}"
        )

    id_to_tool_call_map: dict[int, dict[str, Any]] = {}
    reasoning_start = False
    answer_start = False
    accumulated_reasoning = ""
    accumulated_answer = ""

    with generation_span(
        model=llm.config.model_name,
        model_config={
            "base_url": str(llm.config.api_base or ""),
            "model_impl": "litellm",
        },
    ) as span_generation:
        span_generation.span_data.input = cast(
            Sequence[Mapping[str, Any]], llm_msg_history
        )
        for packet in llm.stream(
            prompt=llm_msg_history,
            tools=tool_definitions,
            tool_choice=tool_choice,
            structured_response_format=None,  # TODO
        ):
            if packet.usage:
                usage = packet.usage
                span_generation.span_data.usage = {
                    "input_tokens": usage.prompt_tokens,
                    "output_tokens": usage.completion_tokens,
                    "cache_read_input_tokens": usage.cache_read_input_tokens,
                    "cache_creation_input_tokens": usage.cache_creation_input_tokens,
                }
            delta = packet.choice.delta

            # Should only happen once, frontend does not expect multiple
            # ReasoningStart or ReasoningDone packets.
            if delta.reasoning_content:
                accumulated_reasoning += delta.reasoning_content
                # Save reasoning incrementally to state container
                state_container.set_reasoning_tokens(accumulated_reasoning)
                if not reasoning_start:
                    yield Packet(
                        turn_index=turn_index,
                        obj=ReasoningStart(),
                    )
                yield Packet(
                    turn_index=turn_index,
                    obj=ReasoningDelta(reasoning=delta.reasoning_content),
                )
                reasoning_start = True

            if delta.content:
                if reasoning_start:
                    yield Packet(
                        turn_index=turn_index,
                        obj=ReasoningDone(),
                    )
                    turn_index += 1
                    reasoning_start = False

                if not answer_start:
                    yield Packet(
                        turn_index=turn_index,
                        obj=AgentResponseStart(
                            final_documents=final_documents,
                        ),
                    )
                    answer_start = True

                for result in citation_processor.process_token(delta.content):
                    if isinstance(result, str):
                        accumulated_answer += result
                        # Save answer incrementally to state container
                        state_container.set_answer_tokens(accumulated_answer)
                        yield Packet(
                            turn_index=turn_index,
                            obj=AgentResponseDelta(content=result),
                        )
                    elif isinstance(result, CitationInfo):
                        yield Packet(
                            turn_index=turn_index,
                            obj=result,
                        )

            if delta.tool_calls:
                if reasoning_start:
                    yield Packet(
                        turn_index=turn_index,
                        obj=ReasoningDone(),
                    )
                    turn_index += 1
                    reasoning_start = False

                for tool_call_delta in delta.tool_calls:
                    _update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)

        tool_calls = _extract_tool_call_kickoffs(id_to_tool_call_map)
        if tool_calls:
            tool_calls_list: list[ToolCall] = [
                {
                    "id": kickoff.tool_call_id,
                    "type": "function",
                    "function": {
                        "name": kickoff.tool_name,
                        "arguments": json.dumps(kickoff.tool_args),
                    },
                }
                for kickoff in tool_calls
            ]

            assistant_msg: AssistantMessage = {
                "role": "assistant",
                "content": accumulated_answer if accumulated_answer else None,
                "tool_calls": tool_calls_list,
            }
            span_generation.span_data.output = [assistant_msg]
        elif accumulated_answer:
            span_generation.span_data.output = [
                {"role": "assistant", "content": accumulated_answer}
            ]
    # Close reasoning block if still open (stream ended with reasoning content)
    if reasoning_start:
        yield Packet(
            turn_index=turn_index,
            obj=ReasoningDone(),
        )
        turn_index += 1

    # Flush any remaining content from citation processor
    if citation_processor:
        for result in citation_processor.process_token(None):
            if isinstance(result, str):
                accumulated_answer += result
                # Save answer incrementally to state container
                state_container.set_answer_tokens(accumulated_answer)
                yield Packet(
                    turn_index=turn_index,
                    obj=AgentResponseDelta(content=result),
                )
            elif isinstance(result, CitationInfo):
                yield Packet(
                    turn_index=turn_index,
                    obj=result,
                )

    # Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
    # Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
    if LOG_ONYX_MODEL_INTERACTIONS:
        logger.debug(f"Accumulated reasoning: {accumulated_reasoning}")
        logger.debug(f"Accumulated answer: {accumulated_answer}")

    if tool_calls:
        tool_calls_str = "\n".join(
            f"  - {tc.tool_name}: {json.dumps(tc.tool_args, indent=4)}"
            for tc in tool_calls
        )
        logger.debug(f"Tool calls:\n{tool_calls_str}")
    else:
        logger.debug("Tool calls: []")

    return (
        LlmStepResult(
            reasoning=accumulated_reasoning if accumulated_reasoning else None,
            answer=accumulated_answer if accumulated_answer else None,
            tool_calls=tool_calls if tool_calls else None,
        ),
        turn_index,
    )
